import json
import random

import numpy as np
import pandas as pd
import torch

from ModularUtils.ControllerConstants import map_dictfill_to_discrete
from ModularUtils.ControllerModel import get_fake_distribution, get_generated_labels
from ModularUtils.Experiment_Class import Experiment
from ModularUtils.FunctionsConstant import getdoKey, asKey
from ModularUtils.FunctionsDistribution import match_with_true_dist, calculate_TVD, calculate_KL
from ModularUtils.FunctionsTraining import save_results
from ModularUtils.Functions_Plot_Results import plot_saved_results
from Train_By_Components.Causal_TrainGraph import set_trainGraph
from Train_By_Components.Synthetic_TrainByComp import get_intv_dist


def get_expected_loss_interventions(Exp, cur_mechs, label_generators, tvd_diff, kl_diff):
    """Score generated interventional distributions against the SCM, per query.

    Every query in ``Exp.interv_queries`` whose first intervention assigns at
    least one variable is evaluated.  For each intervention setting of such a
    query, labels are generated, discretised and compared with the
    distribution returned by ``get_intv_dist``; the resulting TVD and KL are
    multiplied by the entry of the intervened variables' observational
    distribution for that setting and summed.  The rounded sums are appended
    to ``tvd_diff[query["expr"]]`` and ``kl_diff[query["expr"]]``.

    Args:
        Exp: experiment object; ``interv_queries`` and
            ``Synthetic_Sample_Size`` are read here, and the object is
            forwarded to every helper.
        cur_mechs: not used by this function.
        label_generators: forwarded to ``get_generated_labels``.
        tvd_diff: dict of query expression -> list of TVD values (mutated).
        kl_diff: dict of query expression -> list of KL values (mutated).

    Returns:
        ``(tvd_diff, kl_diff, true_expected_dist, fake_expected_dist)``.  The
        last two dicts accumulate, per outcome key, the true and generated
        probabilities, each weighted by ``1 / len(query["intervs"])``.
    """
    feat = "feature"
    true_expected_dist = {}
    fake_expected_dist = {}

    for query in Exp.interv_queries:
        interventions = query["intervs"]

        # Variables assigned by the first intervention setting of this query.
        intervened_vars = list(interventions[0].keys())
        if not intervened_vars:
            continue

        # Observational distribution of the intervened variables; its key
        # doubles as the name under which the SCM distribution is stored.
        marginal_key = getdoKey(intervened_vars, {})
        obs_dist = get_intv_dist(Exp, intervened_vars, {}, marginal_key, load_scm=1)

        observed_vars = query["obs"]
        uniform_weight = 1 / len(interventions)

        # query layout: {"obs": obs_vars, "intervs": key_val, "expr": intervention["expr"]}
        tvd_sum = 0
        kl_sum = 0
        for intv_key in interventions:
            query_string = getdoKey(observed_vars, intv_key)
            scm_dist = get_intv_dist(Exp, observed_vars, intv_key, query_string, load_scm=1)

            sampled_labels = get_generated_labels(
                Exp, label_generators, {}, {}, intv_key, observed_vars, Exp.Synthetic_Sample_Size
            )
            discrete_labels = map_dictfill_to_discrete(Exp, sampled_labels, observed_vars)
            # The true distribution used below is the one handed back by the matcher.
            obs_tvd, obs_kl, true_dist, fake_dist = match_with_true_dist(
                Exp, observed_vars, discrete_labels, scm_dist, feat, doPrint=False
            )

            print(f'{intv_key}: tvd:{obs_tvd}, kl:{obs_kl} and tvd<={np.sqrt(0.5 * obs_kl)}')

            # Weight this setting by its entry in the observational distribution.
            setting_weight = obs_dist[tuple(intv_key.values())]
            tvd_sum += obs_tvd * setting_weight
            kl_sum += obs_kl * setting_weight

            for key_dist in true_dist:
                if key_dist not in true_expected_dist:
                    true_expected_dist[key_dist] = 0
                    fake_expected_dist[key_dist] = 0

                true_expected_dist[key_dist] += true_dist[key_dist] * uniform_weight
                fake_expected_dist[key_dist] += fake_dist[key_dist] * uniform_weight

        print(f'--->Average tvd:{tvd_sum}, kl:{kl_sum} and tvd<={np.sqrt(0.5 * kl_sum)}')

        expr = query["expr"]
        tvd_diff[expr].append(round(tvd_sum, 4))
        kl_diff[expr].append(round(kl_sum, 4))

    return tvd_diff, kl_diff, true_expected_dist, fake_expected_dist


def trainByCompEvaluation(Exp, cur_hnodes, label_generators, dataset_dict, tvd_diff, kl_diff):
    """Evaluate the label generators, record TVD/KL histories and save them.

    The generators are put in eval mode and, under ``torch.no_grad()``:

    1. For every query in ``Exp.interv_queries`` and every intervention
       setting equal to ``{}`` (the observational case), the generated
       distribution from ``get_fake_distribution`` is compared with the one
       from ``get_intv_dist``; TVD and KL are appended to ``tvd_diff`` /
       ``kl_diff`` under the ``getdoKey`` string when that key already exists
       in ``tvd_diff``.
    2. ``get_expected_loss_interventions`` appends the interventional metrics.
    3. ``save_results`` is called with the updated histories.

    The generators are switched back to train mode afterwards, also when one
    of the steps above raises.  Finally the last (up to) 10 TVD values of each
    history are printed, followed by ``Exp.SAVED_PATH``.

    Args:
        Exp: experiment object; ``interv_queries``, ``SAVED_PATH``,
            ``G_avg_losses`` and ``D_avg_losses`` are read here.
        cur_hnodes: not used by this function.
        label_generators: dict of generator objects exposing ``eval()`` and
            ``train()``.
        dataset_dict: not used; kept so existing callers keep working.
        tvd_diff: dict of query key -> list of TVD values (mutated).
        kl_diff: dict of query key -> list of KL values (mutated).

    Returns:
        ``(tvd_diff, kl_diff)`` after this evaluation round.
    """
    # How many of the most recent values to show per history.
    tail_len = 10

    for gen in label_generators:
        label_generators[gen].eval()

    try:
        with torch.no_grad():
            # Observational metrics, so that learning of each mechanism can be followed.
            for query in Exp.interv_queries:
                compare_Var = query["obs"]
                if len(compare_Var) == 0:
                    continue

                for key in query["intervs"]:
                    # Only the empty intervention is handled in this pass;
                    # the others are covered by get_expected_loss_interventions.
                    if key != {}:
                        continue

                    intv_key = asKey(key)
                    query_str = getdoKey(compare_Var, dict(intv_key))

                    fake_dist_dict = get_fake_distribution(Exp, label_generators, intv_key, compare_Var)
                    true_dist_dict = get_intv_dist(Exp, compare_Var, dict(intv_key), query_str, load_scm=1)

                    obs_tvd = calculate_TVD(fake_dist_dict, true_dist_dict, doPrint=False)
                    obs_kl = calculate_KL(fake_dist_dict, true_dist_dict, doPrint=False)

                    # todo: fix it -- queries without a pre-registered key are silently dropped.
                    if query_str in tvd_diff:
                        tvd_diff[query_str].append(round(obs_tvd, 4))
                        kl_diff[query_str].append(round(obs_kl, 4))

            tvd_diff, kl_diff, _, _ = get_expected_loss_interventions(
                Exp, [], label_generators, tvd_diff, kl_diff
            )

            # No generated / real label samples are collected here, hence the empty dicts.
            save_results(Exp, Exp.SAVED_PATH, {}, {},
                         tvd_diff, kl_diff, Exp.G_avg_losses, Exp.D_avg_losses)
    finally:
        # Always hand the generators back in train mode, even after a failure.
        for gen in label_generators:
            label_generators[gen].train()

    # Print the tail of each history.  Slicing each list on its own avoids an
    # IndexError on an empty tvd_diff, and avoids the "-0" slice that printed
    # whole histories whenever the first list was empty.
    for dist in tvd_diff:
        print("###", dist, " loss%:", [round(val, 4) for val in tvd_diff[dist][-tail_len:]])
    print(Exp.SAVED_PATH)

    return tvd_diff, kl_diff


# Script section: these statements are not wrapped in an
# ``if __name__ == "__main__":`` guard, so they execute whenever this module is
# run or imported.
#
# Build the module-level experiment object named "Exp1", passing
# set_trainGraph, a single feature called "feature" and one empty
# intervention dict.
# NOTE(review): new_experiment=False presumably reuses an already stored
# experiment instead of creating one -- confirm in Experiment.
Exp = Experiment("Exp1", set_trainGraph,
                 new_experiment=False,
                 features=["feature"],
                 Data_intervs=[{}])


# Plot the saved results of Exp with epochs=150 and delta=1.
# NOTE(review): the meaning of the positional None / [] arguments and of
# ``delta`` is defined by plot_saved_results -- confirm there.
plot_saved_results(Exp, None, [], epochs=150, delta=1)